import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import pandas as pd
import argparse
import os
import sys
from typing import Dict, List, Optional
from collections import Counter
import re  # regex parsing
import random  # For sampling distractor reasons

def parse_args() -> argparse.Namespace:
    """Build the CLI parser and return the parsed command-line options.

    The options select which slice of the dataset this process handles
    (``--start`` / ``--end``), where results are written (``--output_dir``),
    and which input files are read (``--csv_path``, ``--annotation_path``).
    Only ``--csv_path`` is required.
    """
    cli = argparse.ArgumentParser(
        description="Run persona-based topic classification over a slice of the dataset.",
    )

    # One (flag, keyword-arguments) entry per option, registered in order below.
    option_table = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        ("--output_dir", {"type": str, "default": "topic_results", "help": "Directory to write JSON results."}),
        ("--csv_path", {"type": str, "required": True, "help": "Path to the input CSV with columns video_id,story"}),
        (
            "--annotation_path",
            {
                "type": str,
                "default": "action_annotation.json",
                "help": "Path to action_annotation.json for sampling distractor actions",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()

# ---------------------------------------------------------------------------
# Persona system prompts, keyed by "<age-range>_<gender>".
#
# Every prompt asks the model to pick one action and to reply in exactly two
# lines, in this order:
#     Answer: <action>
#     Reason: <brief justification>
# main() sends each prompt as the system message and extracts the reply with
# "answer:" / "reason:" regexes, so keep that two-line format intact when
# editing the wording. The prompt text itself is runtime behaviour: changing
# it changes what the model is asked.
# ---------------------------------------------------------------------------
persona_prompts = {
    "18-24_female": """You are a woman aged 18–24 who intuitively understands what resonates with your generation—bold aesthetics, authenticity, humor, pop culture references, and individuality.

You will be shown (1) the STORY of a video advertisement and (2) a LIST OF ACTIONS that a viewer might take after watching it.

Choose the SINGLE best action.

Return EXACTLY two lines:
Answer: <action>
Reason: <brief justification>""",

    "18-24_male": """You are a man aged 18–24 who knows what grabs young men's attention—humor, edge, cultural references, and visual flair.

You will be shown the STORY of a video advertisement and a LIST OF ACTIONS that viewers might consider.

Pick the single best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "25-34_female": """You are a woman aged 25–34 who connects with content that is visually refined, emotionally resonant, and aligned with lifestyle interests—career, wellness, and relationships.

Given the STORY and a LIST OF ACTIONS, select the single best action that fits the story.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "25-34_male": """You are a man aged 25–34 who appreciates content that shows ambition, clarity, innovation, fitness, and smart humor.

Pick ONE action from the provided list that best reflects what a viewer should do.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "35-44_female": """You are a woman aged 35–44 who is drawn to emotionally intelligent storytelling, depth, and purpose.

Choose the single best action from the list.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "35-44_male": """You are a man aged 35–44 who connects with grounded, aspirational content about family, success, and purpose.

Pick the best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "45-54_female": """You are a woman aged 45–54 who appreciates visuals and stories that carry meaning, clarity, and purpose.

Select one best action from the list.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "45-54_male": """You are a man aged 45–54 who values storytelling that emphasizes responsibility, growth, trust, and wisdom.

Choose the best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "55+_female": """You are a woman aged 55 or older who resonates with content that conveys warmth, legacy, and deep emotional meaning.

Pick the single best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",

    "55+_male": """You are a man aged 55 or older who prefers storytelling with sincerity, meaning, and timeless values.

Select one best action.

Return exactly two lines:
Answer: <action>
Reason: <brief justification>""",
}



# The CSV's 'reasons' column (a JSON array or a ';'-separated string) is only
# used as a fallback when a video_id has no entry in the annotation JSON.

# Number of distractor actions mixed in with the ground-truth actions.
_NUM_DISTRACTORS = 25

# Patterns for the two-line "Answer: ... / Reason: ..." reply format.
_ANSWER_RE = re.compile(r"(?i)^answer:\s*(.+)$", re.MULTILINE)
_REASON_RE = re.compile(r"(?i)^reason:\s*(.+)$", re.MULTILINE)
# "3", "3.", "3)" or "3. <text>" — a reference to a numbered list entry.
# The separator must be followed by whitespace or end-of-string so that
# text such as "3.5 stars" is not mistaken for an index.
_INDEXED_ANSWER_RE = re.compile(r"^(\d+)(?:[.)](?:\s+.*)?)?$")


def _load_correct_actions(rec: Dict, video_id: str, annotation_data: Dict) -> List:
    """Return the ground-truth actions for one record (annotation JSON first, CSV 'reasons' as fallback)."""
    actions = annotation_data.get(video_id) or []

    if not actions:
        raw = rec.get('reasons', '')  # legacy column name
        if isinstance(raw, str):
            try:
                parsed = json.loads(raw)
            except ValueError:
                parsed = None
            if isinstance(parsed, (list, str)):
                actions = parsed
            else:
                # Not a JSON array/string: treat the cell as ';'-separated text.
                actions = [part.strip() for part in raw.split(';') if part.strip()]
        else:
            actions = raw

    if isinstance(actions, str):
        actions = [actions]
    elif not isinstance(actions, (list, tuple)):
        # A NaN cell (missing value) or any other scalar is not a usable list.
        return []
    return [a for a in actions if a]


def _parse_model_response(raw_resp: str, candidate_actions: List) -> Dict:
    """Extract the chosen action and its justification from one raw model reply."""
    answer = _ANSWER_RE.search(raw_resp)
    chosen = answer.group(1).strip() if answer else raw_resp.strip()

    # The candidates are shown as a numbered list, so the model may answer
    # with the number (optionally followed by the text) instead of the text.
    indexed = _INDEXED_ANSWER_RE.match(chosen)
    if indexed and chosen not in candidate_actions:
        position = int(indexed.group(1))
        if 1 <= position <= len(candidate_actions):
            chosen = candidate_actions[position - 1]

    reason = _REASON_RE.search(raw_resp)
    return {
        'action': chosen,
        'explanation': reason.group(1).strip() if reason else "",
        'raw': raw_resp,
    }


def _majority_action(persona_predictions: Dict) -> str:
    """Return the most common action across personas, ignoring personas whose generation failed."""
    if not persona_predictions:
        return "error_no_predictions"
    # Failed personas are stored as the plain string "error", not a dict.
    votes = [
        p['action']
        for p in persona_predictions.values()
        if isinstance(p, dict) and p['action'] != "error"
    ]
    if not votes:
        return "error_no_valid_predictions"
    return Counter(votes).most_common(1)[0][0]


def _save_results(output_path: str, results: List) -> None:
    """Write the accumulated results list to ``output_path`` as indented JSON."""
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(results, f, indent=4)


def main():
    """Run persona-based action selection over one slice of the input CSV.

    Steps:
      1. Parse CLI arguments and read the annotation JSON and the CSV. These
         are read before the model is loaded so that a bad path fails fast.
      2. Load the chat model and tokenizer once (kept in the module globals
         ``model`` and ``tokenizer``).
      3. For each record in the slice, build a shuffled candidate list from
         the ground-truth actions plus up to ``_NUM_DISTRACTORS`` random
         distractors, query the model once per persona, and majority-vote
         the answers.
      4. Rewrite the output JSON after every record so that partial progress
         survives an interruption.

    Exits with status 1 if the annotation JSON or the CSV cannot be read.
    Records without ground-truth actions are skipped. A persona whose
    generation fails is recorded as the string ``"error"`` and is excluded
    from the vote.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # -------------------------------------------------------------------
    # Load full action annotation to draw distractor actions
    # -------------------------------------------------------------------
    try:
        with open(args.annotation_path, "r", encoding="utf-8") as f:
            annotation_data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error reading annotation JSON {args.annotation_path}: {e}")
        sys.exit(1)

    # Flatten all actions into a single pool we can sample from
    all_actions_pool = [act for acts in annotation_data.values() for act in acts]

    # Load CSV data
    try:
        df = pd.read_csv(args.csv_path)
    except Exception as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # --------------------------------------------------------------
    # Load Qwen chat model once
    # --------------------------------------------------------------
    global model, tokenizer
    model_name = "Qwen/Qwen3-32B"  #change to meta-llama/Llama-3.3-70B-Instruct for LlaMA
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        load_in_4bit=True,
    )

    # Determine slice for this run (both bounds inclusive)
    start_idx = args.start
    end_idx = len(all_records) - 1 if args.end is None else min(args.end, len(all_records) - 1)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        # Bound outside the try so the error handler below always reports
        # the id of the record that actually failed.
        video_id = str(rec.get('video_id', '')).strip()
        try:
            story_text = rec.get('story', '')
            if not isinstance(story_text, str):
                # A missing CSV cell arrives as NaN; do not feed "nan" to the model.
                story_text = '' if pd.isna(story_text) else str(story_text)

            correct_actions = _load_correct_actions(rec, video_id, annotation_data)
            if not correct_actions:
                print(f"No actions found for id {video_id}; skipping")
                continue

            # -------------------------------------------------------------------
            # Build candidate list: all correct actions + up to 25 random distractors
            # -------------------------------------------------------------------
            distractor_pool = [a for a in all_actions_pool if a not in correct_actions]
            distractor_actions = random.sample(
                distractor_pool, min(_NUM_DISTRACTORS, len(distractor_pool))
            )

            candidate_actions = correct_actions + distractor_actions
            random.shuffle(candidate_actions)

            # str.split() with no argument already drops every whitespace
            # character (including newlines and form feeds).
            cleaned_text = ' '.join(story_text.split())

            # The user message is identical for every persona; build it once.
            numbered_actions = "\n".join(f"{i+1}. {a}" for i, a in enumerate(candidate_actions))
            user_content = (
                f"Story:\n{cleaned_text}\n\nList of actions:\n"
                + numbered_actions
                + "\n\nReturn exactly two lines:\nAnswer: <action>\nReason: <brief justification>"
            )

            persona_predictions = {}
            for persona_name, sys_prompt in persona_prompts.items():
                messages = [
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": user_content},
                ]

                try:
                    # Qwen inference
                    input_ids = tokenizer.apply_chat_template(
                        messages,
                        tokenize=True,
                        add_generation_prompt=True,
                        return_tensors="pt",
                        enable_thinking=False,
                    ).to(model.device)

                    with torch.no_grad():
                        outputs = model.generate(
                            input_ids=input_ids,
                            max_new_tokens=300,
                            temperature=0.85,
                            do_sample=True,
                            min_p=0.1,
                        )

                    # Decode only the newly generated tokens, not the prompt.
                    raw_resp = tokenizer.decode(
                        outputs[0][len(input_ids[0]):], skip_special_tokens=True
                    ).strip()

                    persona_predictions[persona_name] = _parse_model_response(
                        raw_resp, candidate_actions
                    )
                except Exception as e:
                    # Best-effort: one failed persona must not lose the others.
                    print(f"Error during model generation for key {video_id}, persona {persona_name}: {e}")
                    persona_predictions[persona_name] = "error"

            # Majority vote for the final action
            final_action = _majority_action(persona_predictions)

            # Store results
            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'persona_predictions': persona_predictions,
                'final_action': final_action,
                'candidate_actions': candidate_actions,
                'correct_actions': correct_actions,
            })

            # Incremental save
            _save_results(output_path, results)

        except Exception as e:
            print(f"Error processing key {video_id}: {e}")
            continue

    if not results:
        # Nothing was written inside the loop; still create the (empty)
        # output file so the message below is accurate.
        _save_results(output_path, results)

    print(f"Finished processing. Results saved to {output_path}")


if __name__ == "__main__":
    main()




